use hybrid_array::{Array, ArraySize};
use typenum::Unsigned;

use crate::wots::WotsSig;
use crate::{PkSeed, SkSeed};
use crate::{address, wots::WotsParams};
use core::fmt::Debug;

/// An XMSS signature: a WOTS+ signature on an `N`-byte message together with
/// the authentication path needed to recompute the XMSS tree root from the
/// WOTS+ public key (see `XmssParams::xmss_pk_from_sig`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct XmssSig<P: XmssParams> {
    /// WOTS+ signature produced with the key pair at the signing leaf index.
    pub(crate) sig: WotsSig<P>,
    /// Authentication path: one `N`-byte sibling node per tree level
    /// (`HPrime` levels), ordered from the leaf level (height 0) upward.
    pub(crate) auth: Array<Array<u8, P::N>, P::HPrime>,
}

impl<P: XmssParams> XmssSig<P> {
    /// Length in bytes of a serialized XMSS signature: the WOTS+ signature
    /// followed by `HPrime` authentication-path nodes of `N` bytes each.
    pub const SIZE: usize = WotsSig::<P>::SIZE + P::HPrime::USIZE * P::N::USIZE;

    /// Serializes this signature into `buf`.
    ///
    /// Layout: the WOTS+ signature bytes, then each authentication-path node
    /// in order. `buf` is expected to be exactly [`Self::SIZE`] bytes long;
    /// this is only checked in debug builds.
    pub fn write_to(&self, buf: &mut [u8]) {
        debug_assert!(buf.len() == Self::SIZE, "Xmss serialize length mismatch");

        let (wots_out, path_out) = buf.split_at_mut(WotsSig::<P>::SIZE);
        self.sig.write_to(wots_out);

        let node_len = P::N::USIZE;
        for (node, out) in self.auth.iter().zip(path_out.chunks_exact_mut(node_len)) {
            out.copy_from_slice(node.as_slice());
        }
    }

    /// Test helper: serializes this signature into a freshly allocated buffer.
    #[cfg(feature = "alloc")]
    #[cfg(test)]
    pub fn to_vec(&self) -> Vec<u8> {
        let mut out = vec![0u8; Self::SIZE];
        self.write_to(out.as_mut_slice());
        out
    }
}

impl<P: XmssParams> TryFrom<&[u8]> for XmssSig<P> {
    // TODO: Real error
    type Error = ();

    /// Parses a serialized XMSS signature: the WOTS+ signature followed by
    /// `HPrime` authentication-path nodes of `N` bytes each.
    ///
    /// Fails with `Err(())` when `value` is not exactly [`XmssSig::SIZE`]
    /// bytes long, or when the embedded WOTS+ signature fails to parse.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        if value.len() != Self::SIZE {
            return Err(());
        }

        // The length check above guarantees `path_bytes` holds exactly
        // `HPrime` chunks of `N` bytes.
        let (wots_bytes, path_bytes) = value.split_at(WotsSig::<P>::SIZE);
        let sig = WotsSig::<P>::try_from(wots_bytes)?;

        let mut auth = Array::<Array<u8, P::N>, P::HPrime>::default();
        let node_len = P::N::USIZE;
        for (node, chunk) in auth.iter_mut().zip(path_bytes.chunks_exact(node_len)) {
            node.copy_from_slice(chunk);
        }

        Ok(XmssSig { sig, auth })
    }
}

/// XMSS operations (node computation, signing, root recovery) built on top of
/// the WOTS+ primitives provided by [`WotsParams`].
///
/// All methods have default implementations; an implementor only chooses the
/// tree height `HPrime`.
pub(crate) trait XmssParams: WotsParams + Sized {
    /// Height of a single XMSS tree. The root is the node at height `HPrime`,
    /// and the tree has `2^HPrime` leaves.
    type HPrime: ArraySize + Debug + Eq;

    /// Computes the node with index `node` at height `height` of the XMSS tree
    /// (`xmss_node` in SLH-DSA).
    ///
    /// A height-0 node is the WOTS+ public key of key pair `node`. Any higher
    /// node is `H` applied to its two children, so this evaluates every leaf
    /// below the requested node recursively.
    ///
    /// Requires `height <= HPrime` and `node < 2^(HPrime - height)`. This is
    /// only checked in debug builds.
    fn xmss_node(
        sk_seed: &SkSeed<Self::N>,
        node: u32,
        height: u32,
        pk_seed: &PkSeed<Self::N>,
        adrs: &address::WotsHash,
    ) -> Array<u8, Self::N> {
        debug_assert!(height <= Self::HPrime::U32);
        debug_assert!(node < (1 << (Self::HPrime::U32 - height)));
        if height == 0 {
            // Leaf: WOTS+ public key of key pair `node`.
            let mut adrs = adrs.clone();
            adrs.key_pair_adrs.set(node);
            Self::wots_pk_gen(sk_seed, pk_seed, &adrs)
        } else {
            let lnode = Self::xmss_node(sk_seed, 2 * node, height - 1, pk_seed, adrs);
            let rnode = Self::xmss_node(sk_seed, 2 * node + 1, height - 1, pk_seed, adrs);
            let mut adrs = adrs.tree_adrs();
            adrs.tree_height.set(height);
            adrs.tree_index.set(node);
            Self::h(pk_seed, &adrs, &lnode, &rnode)
        }
    }

    /// Signs the `N`-byte message `m` with the WOTS+ key pair at leaf `idx`.
    /// It also collects the authentication path for that leaf (`xmss_sign` in
    /// SLH-DSA).
    ///
    /// `auth[j]` is the sibling, at height `j`, of the ancestor of leaf `idx`.
    ///
    /// Requires `idx < 2^HPrime`. This is only checked in debug builds.
    fn xmss_sign(
        m: &Array<u8, Self::N>,
        sk_seed: &SkSeed<Self::N>,
        pk_seed: &PkSeed<Self::N>,
        idx: u32,
        adrs: &address::WotsHash,
    ) -> XmssSig<Self> {
        debug_assert!(idx < (1 << Self::HPrime::U32), "XMSS leaf index out of range");

        let mut adrs = adrs.clone();
        adrs.key_pair_adrs.set(idx);

        let sig = Self::wots_sign(m, sk_seed, pk_seed, &adrs);

        let mut auth = Array::<Array<u8, Self::N>, Self::HPrime>::default();
        for j in 0..Self::HPrime::U32 {
            // Index of the ancestor of `idx` at height `j`, with its low bit
            // flipped: that ancestor's sibling.
            let sibling = (idx >> j) ^ 1;
            auth[j as usize] = Self::xmss_node(sk_seed, sibling, j, pk_seed, &adrs);
        }

        XmssSig { sig, auth }
    }

    /// Recomputes the XMSS root from signature `sig` on message `m` at leaf
    /// `idx` (`xmss_pkFromSig` in SLH-DSA).
    ///
    /// For a valid signature, the result equals
    /// `xmss_node(sk_seed, 0, HPrime, pk_seed, adrs)` for the signing key.
    /// For an invalid one it yields an unrelated value. No panic or error is
    /// produced, so callers compare the result against the expected root.
    fn xmss_pk_from_sig(
        idx: u32,
        sig: &XmssSig<Self>,
        m: &Array<u8, Self::N>,
        pk_seed: &PkSeed<Self::N>,
        adrs: &address::WotsHash,
    ) -> Array<u8, Self::N> {
        let mut adrs = adrs.clone();
        adrs.key_pair_adrs.set(idx);

        // Start from the leaf: the WOTS+ public key implied by the signature.
        let mut node = Self::wots_pk_from_sig(&sig.sig, m, pk_seed, &adrs);

        let mut adrs = adrs.tree_adrs();

        // Walk up the tree. At each level, the low bit of the current index
        // says whether `node` is a right child (1) or a left child (0).
        let mut idx = idx;
        for (height, sibling) in (1u32..).zip(sig.auth.iter()) {
            let is_right_child = idx & 1 == 1;
            idx >>= 1;
            adrs.tree_height.set(height);
            adrs.tree_index.set(idx);
            node = if is_right_child {
                Self::h(pk_seed, &adrs, sibling, &node)
            } else {
                Self::h(pk_seed, &adrs, &node, sibling)
            };
        }
        node
    }
}

#[cfg(test)]
mod tests {
    use crate::util::macros::test_parameter_sets;
    use crate::{PkSeed, SkSeed};
    use crate::{address::WotsHash, hashes::Shake128f, xmss::XmssParams};
    use hex_literal::hex;
    use hybrid_array::Array;
    use rand::{Rng, RngCore, rng};
    use typenum::Unsigned;

    #[test]
    fn test_xmss_node_shake128f_kat() {
        let root_height = <Shake128f as XmssParams>::HPrime::U32;
        let root = Shake128f::xmss_node(
            &SkSeed(Array([1; 16])),
            0,
            root_height,
            &PkSeed(Array([2; 16])),
            &WotsHash::default(),
        );

        // Generated by https://github.com/mjosaarinen/slh-dsa-py
        assert_eq!(root.as_slice(), hex!("94e24679fb2460b97332db131c38bec9"));
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sign_shake128f_kat() {
        let message = Array([3; 16]);
        let leaf_index = 3;
        let signature = Shake128f::xmss_sign(
            &message,
            &SkSeed(Array([1; 16])),
            &PkSeed(Array([2; 16])),
            leaf_index,
            &WotsHash::default(),
        );

        let expected = hex!(
            "
        a77a0b07e558b023f653a954d886ac66ded67b313f9db7fd93da00686be66a3f
        2e2d3e841292bf5a4060d88509e9a2a51e0bbae6835482bceabce76c5653546d
        08c2f5f78e7491f755f35380d965598891131bdd4c57df2397eed8062a1038fb
        10c758bb30c6ea3859db4eb6296269d170d86cc67804dc63a61e5f30af709aad
        2407624eb81549e87c326c2a646c2b995dfad81cc007286b6f50b56f61352fa2
        752a30aa4f63cc367a7a1c57140a086cc43387ce5f530d84538d0c503d051be2
        9c0040486c2953d34e3817bfcb6f198e545476ddd93930af48333b4e7e0eba03
        3bdbc1badca23875d2f4345699075558a68c8f53865c0b2151208a7a5a4b0c7d
        270b71d5688c6d727525e3fd9c75b9656e13394777faee925fe8cda6e2b7c52a
        684f218679a48b942127f89ffaa069db21659a09266e9304ce870c16094bf585
        6ed93c0748b9479a95d4309c74c2da26b2cf2e5f2090f02601b80c3373b14666
        f0bd973d10c7eb649966d1ffd3e87979899812fef1e23f5703a99924001d9ba9
        522ea93575ad20143eeeeff77b8d192870932b1583459271f634a65441fe1907
        370f71e4d9312b930a66e1b85cba8f4a404c703c7c38ada5c6b95824c2c0ff87
        b1e3f258189d949430c516d2c2192ffbb8d687b10228d7ecf47f86c1299825a8
        b6ee7c560f4bd1720aabdca41c8a5569e9917f906efca17d5f080e65e5a16386
        c9bb4f1ad49404340df212e94d77ff5a25b8649b725e1993dc66f37a89058499
        107bb57a4f699688406e89a44776b95bd1af01290496fb4f3abba58eb407eff9
        c1dfd1362d169170f8b7364c6aa8e6507f049484e5d9b934e86d61b1d3155b5a"
        );

        assert_eq!(signature.to_vec(), expected);
    }

    /// Shared round-trip driver for the random sign/verify tests.
    ///
    /// It draws random seeds, a random message and a random leaf index, and
    /// computes the tree root. It then signs at that leaf and recovers the
    /// root from the signature. When `tamper` is set, the message's first
    /// byte is flipped after signing and before recovery.
    ///
    /// Returns `(root, recovered_root)`.
    fn random_round_trip<Xmss: XmssParams>(
        tamper: bool,
    ) -> (Array<u8, Xmss::N>, Array<u8, Xmss::N>) {
        let mut rng = rng();
        let sk_seed = SkSeed::new(&mut rng);
        let pk_seed = PkSeed::new(&mut rng);

        let mut msg = Array::<u8, _>::default();
        rng.fill_bytes(msg.as_mut_slice());

        let idx = rng.random_range(0..(1 << Xmss::HPrime::U32));
        let adrs = WotsHash::default();

        let root = Xmss::xmss_node(&sk_seed, 0, Xmss::HPrime::U32, &pk_seed, &adrs);
        let sig = Xmss::xmss_sign(&msg, &sk_seed, &pk_seed, idx, &adrs);

        if tamper {
            msg[0] ^= 0xff;
        }

        let recovered = Xmss::xmss_pk_from_sig(idx, &sig, &msg, &pk_seed, &adrs);
        (root, recovered)
    }

    fn test_sign_verify<Xmss: XmssParams>() {
        let (root, recovered) = random_round_trip::<Xmss>(false);
        assert_eq!(root, recovered);
    }

    test_parameter_sets!(test_sign_verify);

    fn test_sign_verify_fail<Xmss: XmssParams>() {
        let (root, recovered) = random_round_trip::<Xmss>(true);
        assert_ne!(root, recovered);
    }

    test_parameter_sets!(test_sign_verify_fail);
}
